
Feat: Handle Adjoints through Initialization #1168


Merged

merged 91 commits into master from dg/initprob on Apr 24, 2025
Conversation

@DhairyaLGandhi DhairyaLGandhi commented Mar 6, 2025

Checklist

  • Appropriate tests were added
  • Any code changes were done in a way that does not break public API
  • All documentation related to code changes were updated
  • The new code follows the
    contributor guidelines, in particular the SciML Style Guide and
    COLPRAC.
  • Any new documentation only uses public API

Additional context

MTK and SciML construct an initialization problem before starting the time stepping to ensure the starting values of the unknowns and parameters adhere to any constraints on the system. This PR adds handling for adjoint sensitivities of the NonlinearProblem, NonlinearLeastSquaresProblem, SCCNonlinearProblem, etc. that this initialization produces.

I am opening this to get some feedback regarding how we can accumulate gradients correctly. I have also included a test case for a DAE which I will update to use the values out of SciMLSensitivity.


Currently the gradients get calculated but not accumulated; we need to be able to update the gradients for the parameters. Since this is a manual dispatch, the usual graph building in AD is bypassed, and we need to handle accumulation ourselves. Ideally, we should make it so the cfg itself includes the initialization, so that we would not have gotten incorrect gradients in the first place 😅 We are also forced to use a LinearProblem instead of `\` because `\` cannot handle singular Jacobians.
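To make the accumulation issue concrete, here is a minimal pure-Julia sketch (every function is a hypothetical toy stand-in, not SciMLSensitivity code): the loss depends on the parameters both directly through the solve and indirectly through the initialized u0, and the correct gradient is the sum of both paths.

```julia
# Toy model of the two gradient paths; all names here are hypothetical.
initfn(p)     = 2p           # "initialization": u0 determined by p
stepfn(u0, p) = u0 + 3p      # "solve": propagates u0 and p to the state
lossfn(u)     = u^2

p  = 1.5
u0 = initfn(p)
u  = stepfn(u0, p)

# Reverse pass by hand:
dLdu      = 2u               # adjoint of the loss
dp_direct = dLdu * 3         # path through ∂stepfn/∂p
igs       = dLdu * 1 * 2     # path through ∂stepfn/∂u0 * ∂initfn/∂p

# The correct parameter gradient accumulates BOTH contributions;
# dropping igs is exactly the "calculated but not accumulated" bug.
dp_total = dp_direct + igs
@assert dp_total ≈ 50p       # analytic check: u = 5p, L = 25p^2, dL/dp = 50p
```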

cc @ChrisRackauckas

@DhairyaLGandhi (Member Author)

I wanted to ask whether it is preferred to retain `\` for the smaller problems, which would be slower with LinearSolve; if so, we will need to deal with LAPACK errors.

@ChrisRackauckas (Member)

I wanted to ask whether it is preferred to retain \ for the smaller problems, which would be slower with LinearSolve, and if so, we will need to deal with LAPACK errors.

LinearSolve.jl should be faster across the board? It depends a bit on the CPU architecture, since performance hinges on whether it guesses the right LU implementation correctly.
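For reference, the singular-Jacobian motivation for preferring a linear-solver abstraction over `\` can be reproduced with just the LinearAlgebra stdlib: square `\` goes through LU and throws on an exactly singular matrix, while a pivoted (rank-revealing) QR still returns a least-squares solution. A stdlib-only sketch:

```julia
using LinearAlgebra

A = [1.0 2.0; 2.0 4.0]       # rank-1, exactly singular
b = [1.0, 2.0]               # b lies in the column space of A

# Square `\` dispatches to LU and throws on the zero pivot:
singular_threw = try
    A \ b
    false
catch err
    err isa SingularException
end
@assert singular_threw

# A pivoted QR handles the rank deficiency via least squares:
x = qr(A, ColumnNorm()) \ b
@assert norm(A * x - b) < 1e-10
```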

@ChrisRackauckas (Member)

Note that with the latest MTK update, there is now an Initials section where the initial u0 lives. That should fix a few things.

@@ -425,6 +425,21 @@ function DiffEqBase._concrete_solve_adjoint(
save_end = true, kwargs_fwd...)
end

# Get gradients for the initialization problem if it exists
igs = if _prob.f.initialization_data.initializeprob != nothing
Member

this should be before the solve, since you can use the initialization solution from here in the remakes of 397-405 in order to set new u0 and p and thus skip running the initialization a second time.

Member Author

How can I indicate to solve to avoid running initialization?

Member

initializealg = NoInit(). Should probably just do CheckInit() for safety but either is fine.

@@ -103,15 +102,18 @@ end
else
if linsolve === nothing && isempty(sensealg.linsolve_kwargs)
# For the default case use `\` to avoid any form of unnecessary cache allocation
Member

Yeah, I don't know about that comment. I think it's just old. (a) `\` always allocates because it uses lu instead of lu!, so it is re-allocating the whole matrix, which is larger than any LinearSolve allocation, and (b) we have since 2023 set up tests on StaticArrays, so the immutable path is non-allocating. I don't think (b) was true when this was written.
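Point (a) can be checked directly with the LinearAlgebra stdlib: `lu` copies the matrix before factorizing, while `lu!` works in place, which is the behavior a cached linear-solve path relies on. A small sketch:

```julia
using LinearAlgebra

A = rand(100, 100)
b = rand(100)

# `\` (via `lu`) copies A before factorizing:
F1 = lu(A)
@assert F1.factors !== A      # a fresh 100×100 matrix was allocated

# `lu!` factorizes in place, which is what a cached solver path reuses:
B  = copy(A)                  # keep A intact for the residual check
F2 = lu!(B)
@assert F2.factors === B      # same storage, no matrix copy

x = F2 \ b
@assert norm(A * x - b) <= 1e-8 * norm(b)
```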

Member Author

So glad we can remove this branch altogether.

iprob = _prob.f.initialization_data.initializeprob
ip = parameter_values(iprob)
itunables, irepack, ialiases = canonicalize(Tunable(), ip)
igs, = Zygote.gradient(ip) do ip
Member

This gradient isn't used? I think this would go into the backpass and if I'm thinking clearly, the resulting return is dp .* igs?

Member Author

Not yet. These gradients are currently against the parameters of the initialization problem, not the system exactly, and the mapping between the two is ill defined, so we cannot simply accumulate.

I spoke with @AayushSabharwal about a way to map; it seems initialization_data.initializeprobmap might have some support to return the correctly shaped vector, but there are cases where we cannot know the ordering of dp either.

Member Author

There's another subtlety. I am not sure we haven't missed some part of the cfg by manually handling accumulation of gradients, or any transforms we might need to calculate gradients for. The regular AD graph building typically took care of these details for us, but in this case we need to watch for incorrect gradients manually.

Member

Oh yes, you need to use the initializeprobmap https://github.com/SciML/SciMLBase.jl/blob/master/src/initialization.jl#L268 to map it back to the shape of the initial parameters.

but there are cases where we cannot know the ordering of dp either.

p and dp just need the same ordering, so initializeprobmap should do the trick.

There's another subtlety. I am not sure we haven't missed some part of the cfg by manually handling accumulation of gradients, or any transforms we might need to calculate gradients for. The regular AD graph building typically took care of these details for us, but in this case we need to watch for incorrect gradients manually.

This is the only change to (u0, p) before solving, so this would account for it, given that initializeprobmap is just an index map and hence effectively an identity function.
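Assuming initializeprobmap really is just an index map (a gather of system parameters into the initialization problem's ordering), its adjoint is the corresponding scatter-add, so mapping igs back into dp is a permutation-style accumulation. A toy sketch with a hypothetical index vector:

```julia
# Hypothetical index map standing in for initializeprobmap: the i-th
# initialization-problem parameter is system parameter idx[i].
idx = [3, 1, 4]

gather(p_sys) = p_sys[idx]    # forward: system -> initialization ordering

# The adjoint of a gather is a scatter-add back into system ordering:
function scatter_add!(dp_sys, igs)
    for (i, j) in enumerate(idx)
        dp_sys[j] += igs[i]
    end
    dp_sys
end

igs    = [1.0, 2.0, 3.0]      # gradients w.r.t. initialization parameters
dp_sys = zeros(4)             # direct gradients would accumulate here too
scatter_add!(dp_sys, igs)
@assert dp_sys == [2.0, 0.0, 1.0, 3.0]
```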

Member Author

Addressed this occurrence in 95ebbf3 to check whether this is correct. Will need to work around the global call.

@DhairyaLGandhi (Member Author)

Trying to use the initialization end to end caused gradients against parameters to get dropped. https://github.com/DhairyaLGandhi/SciMLBase.jl/tree/dg/nonlinear is a WIP branch which adds adjoints to the getindex calls; this does capture the expected gradients, but we still end up dropping gradients somewhere in the chain. I am looking into whether we are doing so in the custom adjoints, because that was an issue I had identified for the ODE case.

new_u0, new_p, _ = SciMLBase.get_initial_values(
    new_prob, new_prob, new_prob.f, SciMLBase.OverrideInit(), Val(true);
    abstol = 1e-6, reltol = 1e-6,
    sensealg = SteadyStateAdjoint(autojacvec = ZygoteVJP()))
Member

shouldn't default to ZygoteVJP. Should use the autojacvec of the ODE

Member Author

Addressed in 9a8a845

@DhairyaLGandhi (Member Author)

I could use some help understanding how to handle initialization when MTK analytically solves the problem and removes all the unknowns. In that case u0, u, etc. are empty, and we attempt to calculate the Jacobian for them in dgdu_val, which seems weird since there are no unknowns to calculate the Jacobian for. It would be equally incorrect to use the unknowns from the system itself. The best way to handle this might be to simply add a check for isempty and return. Is that reasonable?
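A sketch of that isempty guard, with a hypothetical finite-difference Jacobian helper standing in for the actual adjoint-Jacobian computation:

```julia
# Hypothetical helper: when MTK has analytically eliminated every unknown,
# there is nothing to differentiate, so return an empty Jacobian early
# instead of attempting a 0-dimensional computation downstream.
function state_jacobian(f, u)
    isempty(u) && return Matrix{eltype(u)}(undef, 0, 0)
    # toy forward-difference Jacobian for the nonempty case
    n, h = length(u), sqrt(eps(eltype(u)))
    fu = f(u)
    J = zeros(eltype(u), n, n)
    for j in 1:n
        du = copy(u); du[j] += h
        J[:, j] = (f(du) .- fu) ./ h
    end
    J
end

# Empty unknowns short-circuit instead of erroring:
@assert size(state_jacobian(identity, Float64[])) == (0, 0)
# Nonempty case still computes the Jacobian (d(u^2)/du = 2u = 6 at u = 3):
@assert isapprox(state_jacobian(u -> u .^ 2, [3.0]), reshape([6.0], 1, 1); atol = 1e-5)
```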

@DhairyaLGandhi (Member Author)

DhairyaLGandhi commented Apr 23, 2025

There seems to be an InitialFailure with DefaultInit

@DhairyaLGandhi (Member Author)

julia> prob, ps, init = setups[3];

julia> init
OrdinaryDiffEqCore.DefaultInit()

julia> new_sol = solve(prob, Rodas5P(); p = p, sensealg, initializealg = init, abstol = 1e-6, reltol = 1e-3)
┌ Warning: Initialization system is overdetermined. 1 equations for 0 unknowns. Initialization will default to using least squares. `SCCNonlinearProblem` can only be used for initialization of fully determined systems and hence will not be used here. To suppress this warning pass warn_initialize_determined = false. To make this warning into an error, pass fully_determined = true
└ @ ModelingToolkit ~/.julia/packages/ModelingToolkit/aau6A/src/systems/diffeqs/abstractodesystem.jl:1520
retcode: InitialFailure
Interpolation: specialized 4rd order "free" stiffness-aware interpolation
t: 1-element Vector{Float64}:
 0.0
u: 1-element Vector{Vector{Float64}}:
 [2.0, 0.0, 0.0, 1.0, 0.0]

w2 => -1.0,]
prob_correctu0 = ODEProblem(sys, u0_correct, tspan, p, jac = true, guesses = [w2 => -1.0])
mtkparams_correctu0 = SciMLSensitivity.parameter_values(prob_correctu0)
prob_correctu0.u0[5] = -1.0
Member

@AayushSabharwal note that this is the issue of not passing over the initial values and guess values to u0. Without this, CheckInit will fail. It's a minor issue, but it is something we should handle over time.

Member

Yeah, I'm working on using the hook in DiffEqBase. It's easy enough to implement, but somehow the returned problem's u0 isn't being respected.

Comment on lines 752 to 764
# if !hasmethod(Zygote.adjoint,
# Tuple{Zygote.AContext, typeof(Zygote.literal_getproperty),
# SciMLBase.AbstractTimeseriesSolution, Val{:u}})
# Zygote.@adjoint function Zygote.literal_getproperty(sol::AbstractTimeseriesSolution,
# ::Val{:u})
# function solu_adjoint(Δ)
# zerou = zero(sol.prob.u0)
# _Δ = @. ifelse(Δ === nothing, (zerou,), Δ)
# (SciMLBase.build_solution(sol.prob, sol.alg, sol.t, _Δ),)
# end
# sol.u, solu_adjoint
# end
# end
Member

dead code?

Member Author

I made a mistake pushing cdaa2c7, so I have reverted it.

Member Author

Apologies

Project.toml Outdated
@@ -1,7 +1,7 @@
name = "SciMLSensitivity"
uuid = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
authors = ["Christopher Rackauckas <[email protected]>", "Yingbo Ma <[email protected]>"]
version = "7.77.0"
version = "7.76.0"
Member

?

Member Author

Ah, this was picked up from the history; updated. Thanks!

@DhairyaLGandhi DhairyaLGandhi merged commit 5f6ab43 into master Apr 24, 2025
14 of 18 checks passed
@ChrisRackauckas ChrisRackauckas deleted the dg/initprob branch April 24, 2025 22:58